Add support for float mask to aten::masked_fill #1337

Open

gpetters94 wants to merge 1 commit into main from masked_fill

Conversation

gpetters94 (Collaborator)

This adds a few small changes that are needed for OPT support, namely:

  • Support for aten::view when the output shape is statically known

  • Folding away torch::type_as when both arguments are the same type

  • Support for torch::masked_fill when the mask is a float type (see the sketch below)
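
As an illustration of the masked_fill item (the subject this PR was later narrowed to), here is a minimal libtorch sketch of the kind of call being targeted. The tensor values and the use of the C++ API are assumptions made for illustration, not code from this PR, and per the discussion below eager PyTorch's handling of a float mask is itself ambiguous.

// Minimal sketch (assumed example, not from this PR): masked_fill invoked with
// a mask tensor whose dtype is float rather than bool.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor input = torch::zeros({4});
  // Float-typed mask; a bool mask is the conventional choice.
  torch::Tensor mask = torch::tensor({1.0f, 0.0f, 1.0f, 0.0f});
  torch::Tensor out = input.masked_fill(mask, /*value=*/3.0);
  std::cout << out << "\n";
  return 0;
}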

ramiro050 (Collaborator)

Can you split the changes into 3 PRs? They are all quite independent of one another. It would also make the commit titles a lot more descriptive, since each PR could take its title from the corresponding bullet point above.

gpetters94 changed the title from "Add cases for view, type_as, and masked_fill" to "Add support for float mask to aten::masked_fill" on Sep 9, 2022

gpetters94 (Collaborator, Author)

I've split the PRs; this one is now for masked_fill.

gpetters94 requested a review from ramiro050 on September 9, 2022 17:40

ramiro050 (Collaborator) left a comment

I just have a small change request, but other than that it LGTM.

gpetters94 force-pushed the masked_fill branch 3 times, most recently from 003853f to f73493e on September 14, 2022 16:15
@@ -906,6 +906,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(

Value input = payloadArgs[0];
Value mask = payloadArgs[1];
if (mask.getType().isa<mlir::FloatType>())
  mask = b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));

ramiro050 (Collaborator)

I think I missed this the first time I reviewed your changes. Why is mask being set to false here?

gpetters94 (Collaborator, Author)

That seems to be the expected behavior. I was casting to Int1 at first, but further testing shows that PyTorch seems to treat all float mask values as false. I haven't found anything in the documentation about it.
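
For reference, the cast-to-Int1 approach mentioned above could look roughly like the following inside the elementwise payload; this is a sketch of that alternative under the assumption that a nonzero float should count as true, not code from this PR.

// Sketch (not part of this PR): treat a nonzero float mask element as true by
// comparing against 0.0, producing an i1 value for the rest of the payload.
if (mask.getType().isa<mlir::FloatType>()) {
  Value zero = b.create<arith::ConstantOp>(
      loc, b.getFloatAttr(mask.getType(), 0.0));
  mask = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, mask, zero);
}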

ramiro050 (Collaborator)

This seems to be a bug upstream. I would actually expect a float mask to result in a runtime error, since this is the behavior that aten.masked_select has:

https://github.com/pytorch/pytorch/blob/5b58140d1a471b144baf66cc61a45a86746f0215/aten/src/ATen/native/TensorAdvancedIndexing.cpp#L1720-L1721
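
For context, the linked lines amount to a dtype check on the mask roughly of this form (a paraphrase, not the verbatim upstream code at that commit):

// Paraphrased sketch of the kind of runtime check masked_select applies to its
// mask; see the link above for the actual upstream lines.
TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool,
            "masked_select: expected BoolTensor for mask");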

Can you file a bug upstream for this?

gpetters94 (Collaborator, Author) commented Sep 15, 2022

Okay, the bug report is up here. I'm going to leave this as-is for now in case it's expected behavior, but I'll add an assert if it isn't.
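
If it does turn out to be unintended upstream behavior, the guard mentioned above could be sketched as follows; the op handle and the error message are assumptions based on the surrounding lowering code, not an actual change.

// Hypothetical guard (not in this PR): reject float masks outright instead of
// folding them to a constant false, mirroring masked_select's runtime check.
if (mask.getType().isa<mlir::FloatType>()) {
  op->emitError("unimplemented: aten.masked_fill only supports integer or "
                "bool masks");
  return nullptr;
}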

qedawkins pushed a commit to nod-ai/torch-mlir that referenced this pull request Oct 3, 2022
Rewrite the deprecated ONNX Scatter operation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scatter) using the equivalent ScatterElements operation.